import numpy as np
import torch
import torch.nn as nn

from geotransformer.modules.ops import pairwise_distance
from geotransformer.modules.transformer import SinusoidalPositionalEmbedding, RPEConditionalTransformer
from geotransformer.modules.transformer.LoRA import LoRALayer



# 
class GeometricStructureEmbedding(nn.Module):
    r"""Relative positional embedding built from point-cloud geometry.

    For every ordered pair of points (i, j) the module produces a ``hidden_dim``
    vector that is the sum of:

    * a distance term: sinusoidal embedding of ``||p_j - p_i|| / sigma_d``;
    * an angular term: sinusoidal embedding of the angles between ``p_j - p_i``
      and the vectors from ``p_i`` to each of its ``angle_k`` nearest neighbours,
      reduced over the neighbours with ``reduction_a``;
    * optionally a hue term (when ``hsv`` is passed to ``forward``);
    * optionally a Gaussian-parameter term (when ``gs_params`` is passed).

    Args:
        hidden_dim: embedding dimension.
        sigma_d: distance temperature (divides the pairwise distance).
        sigma_a: angle temperature in degrees (angles in radians are multiplied
            by ``180 / (sigma_a * pi)``).
        angle_k: number of nearest neighbours used for the angular term.
        reduction_a: ``'max'`` or ``'mean'`` reduction over the k angular terms.
        sigma_color: when not None, creates ``proj_color``, the projection that
            ``forward`` applies to the Gaussian-parameter term.
        sigma_hd: temperature of the hue-times-distance term; required to pass ``hsv``.
        sigma_gs: temperature of the Gaussian-parameter term; required to pass ``gs_params``.

    Raises:
        ValueError: if ``reduction_a`` is not ``'max'`` or ``'mean'``.
    """

    def __init__(self, hidden_dim, sigma_d, sigma_a, angle_k, reduction_a='max', sigma_color=None, sigma_hd=None, sigma_gs=None):
        super(GeometricStructureEmbedding, self).__init__()
        self.sigma_d = sigma_d
        self.sigma_a = sigma_a
        # Converts radians to "degrees / sigma_a" in a single multiplication.
        self.factor_a = 180.0 / (self.sigma_a * np.pi)
        self.angle_k = angle_k
        self.sigma_gs = sigma_gs
        # Always defined (possibly None) so the optional branches can be validated
        # explicitly instead of failing with AttributeError. These are plain
        # attributes, so the state_dict is unaffected.
        self.sigma_hd = sigma_hd
        self.sigma_color = sigma_color

        self.embedding = SinusoidalPositionalEmbedding(hidden_dim)
        self.proj_d = LoRALayer(hidden_dim, hidden_dim)
        self.proj_a = LoRALayer(hidden_dim, hidden_dim)

        if sigma_hd is not None:
            self.proj_hd = LoRALayer(hidden_dim, hidden_dim)
            # NOTE(review): add_mlp is never used in forward(); it is kept,
            # presumably so checkpoints that contain its weights still load strictly.
            self.add_mlp = LoRALayer(2 * hidden_dim, hidden_dim)

        if sigma_color is not None:
            # forward() uses this projection for the Gaussian-parameter (gs) term.
            self.proj_color = LoRALayer(hidden_dim, hidden_dim)

        self.reduction_a = reduction_a
        if self.reduction_a not in ['max', 'mean']:
            raise ValueError(f'Unsupported reduction mode: {self.reduction_a}.')

    @torch.no_grad()
    def get_embedding_indices(self, points, hsv=None, gs_params=None):
        r"""Compute the continuous indices fed to the sinusoidal embedding.

        Args:
            points: torch.Tensor (B, N, 3), input point cloud.
            hsv: optional torch.Tensor, assumed (B, N, C); only channel 0
                (presumably hue) is read. Requires ``sigma_hd``.
            gs_params: optional torch.Tensor, assumed (B, N, 3 + C); channels
                ``[:3]`` are treated as the mean and ``[3:]`` as the covariance
                parameters. Requires ``sigma_gs``.

        Returns:
            d_indices: torch.FloatTensor (B, N, N), pairwise distance / sigma_d.
            a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices.
            hd_indices: torch.FloatTensor (B, N, N) or None,
                ``|h_i - h_j| * distance / sigma_hd``.
            gs_indices: torch.FloatTensor (B, N, N) or None,
                ``(|mean0_i - mean0_j| + |cov0_i - cov0_j|) * distance / sigma_gs``.

        Raises:
            ValueError: if ``hsv`` is given while ``sigma_hd`` is None, or
                ``gs_params`` is given while ``sigma_gs`` is None.
        """
        if hsv is not None and self.sigma_hd is None:
            raise ValueError('hsv was provided but sigma_hd is None; set sigma_hd to use the hue embedding.')
        if gs_params is not None and self.sigma_gs is None:
            raise ValueError('gs_params was provided but sigma_gs is None; set sigma_gs to use the Gaussian embedding.')

        batch_size, num_point, _ = points.shape

        dist_map = torch.sqrt(pairwise_distance(points, points))  # (B, N, N)
        d_indices = dist_map / self.sigma_d

        # Reference vectors: from each point to its k nearest neighbours. The
        # closest hit (position 0) is skipped, assumed to be the point itself.
        k = self.angle_k
        knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:]  # (B, N, k)
        knn_indices = knn_indices.unsqueeze(3).expand(batch_size, num_point, k, 3)  # (B, N, k, 3)
        expanded_points = points.unsqueeze(1).expand(batch_size, num_point, num_point, 3)  # (B, N, N, 3)
        knn_points = torch.gather(expanded_points, dim=2, index=knn_indices)  # (B, N, k, 3)
        ref_vectors = knn_points - points.unsqueeze(2)  # (B, N, k, 3)
        # Anchor vectors: anc_vectors[b, i, j] = p_j - p_i.
        anc_vectors = points.unsqueeze(1) - points.unsqueeze(2)  # (B, N, N, 3)
        ref_vectors = ref_vectors.unsqueeze(2).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3)
        anc_vectors = anc_vectors.unsqueeze(3).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3)
        # atan2(|a x b|, a . b) gives the unsigned angle in [0, pi].
        sin_values = torch.linalg.norm(torch.cross(ref_vectors, anc_vectors, dim=-1), dim=-1)  # (B, N, N, k)
        cos_values = torch.sum(ref_vectors * anc_vectors, dim=-1)  # (B, N, N, k)
        angles = torch.atan2(sin_values, cos_values)  # (B, N, N, k)
        a_indices = angles * self.factor_a

        # Hue term: absolute hue difference weighted by the pairwise distance.
        hd_indices = None
        if hsv is not None:
            h = hsv[:, :, 0].unsqueeze(2)  # (B, N, 1)
            h_t = torch.transpose(h, 2, 1)  # (B, 1, N)
            delta_h = torch.abs(h - h_t)  # (B, N, N)
            hd_indices = delta_h * dist_map / self.sigma_hd

        # Gaussian-parameter term.
        gs_indices = None
        if gs_params is not None:
            mean, cov = gs_params[:, :, :3], gs_params[:, :, 3:]
            # NOTE(review): only channel 0 of the mean and of the covariance
            # parameters is compared; confirm this is intended.
            m = mean[:, :, 0].unsqueeze(2)  # (B, N, 1)
            m_t = torch.transpose(m, 2, 1)  # (B, 1, N)
            c = cov[:, :, 0].unsqueeze(2)  # (B, N, 1)
            c_t = torch.transpose(c, 2, 1)  # (B, 1, N)
            delta_m = torch.abs(m - m_t)  # (B, N, N)
            delta_c = torch.abs(c - c_t)  # (B, N, N)
            gs_indices = (delta_m + delta_c) * dist_map  # (B, N, N)
            gs_indices = gs_indices / self.sigma_gs

        return d_indices, a_indices, hd_indices, gs_indices

    def forward(self, points, hsv=None, gs_params=None):
        r"""Return the summed pairwise embedding for ``points``.

        Args:
            points: torch.Tensor (B, N, 3).
            hsv: optional per-point colour tensor (see ``get_embedding_indices``).
            gs_params: optional per-point Gaussian parameters (see ``get_embedding_indices``).

        Returns:
            torch.Tensor: distance + angular (+ hue) (+ Gaussian) embeddings,
            each produced by ``self.embedding`` followed by its projection.

        Raises:
            ValueError: if an optional input is given without the temperature /
                projection it needs (``sigma_hd`` for ``hsv``; ``sigma_gs`` and
                ``sigma_color`` for ``gs_params``).
        """
        # The gs term is projected with proj_color, which only exists when
        # sigma_color was given; fail early with a clear message otherwise.
        proj_gs = getattr(self, 'proj_color', None)
        if gs_params is not None and proj_gs is None:
            raise ValueError('gs_params was provided but sigma_color is None, so no projection (proj_color) exists for the Gaussian embedding.')

        d_indices, a_indices, hd_indices, gs_indices = self.get_embedding_indices(points, hsv=hsv, gs_params=gs_params)

        d_embeddings = self.embedding(d_indices)
        d_embeddings = self.proj_d(d_embeddings)

        a_embeddings = self.embedding(a_indices)
        a_embeddings = self.proj_a(a_embeddings)
        # Reduce over the k-neighbour axis (dim 3), assuming the embedding
        # appends a trailing feature dimension to the (B, N, N, k) indices.
        if self.reduction_a == 'max':
            a_embeddings = a_embeddings.max(dim=3)[0]
        else:
            a_embeddings = a_embeddings.mean(dim=3)

        embeddings = d_embeddings + a_embeddings

        if hd_indices is not None:
            hd_embeddings = self.embedding(hd_indices)
            hd_embeddings = self.proj_hd(hd_embeddings)
            embeddings = embeddings + hd_embeddings

        if gs_indices is not None:
            gs_embeddings = self.embedding(gs_indices)
            gs_embeddings = proj_gs(gs_embeddings)
            embeddings = embeddings + gs_embeddings

        return embeddings

#
class GeometricTransformer(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        num_heads,
        blocks,
        sigma_d,
        sigma_a,
        angle_k,
        sigma_gs=None,
        sigma_hd=None,
        sigma_color=None,
        dropout=None,
        activation_fn='ReLU',
        reduction_a='max',
    ):
        r"""Geometric Transformer (GeoTransformer).

        Args:
            input_dim: input feature dimension
            output_dim: output feature dimension
            hidden_dim: hidden feature dimension
            num_heads: number of heads in the transformer
            blocks: list of 'self' or 'cross'
            sigma_d: temperature of distance
            sigma_a: temperature of angles
            angle_k: number of nearest neighbors for angular embedding
            sigma_gs: temperature of the Gaussian-parameter embedding; only needed
                when ``ref_gs`` / ``src_gs`` are passed to ``forward`` (default None)
            sigma_hd: temperature of the hue-times-distance embedding; only needed
                when ``ref_colors`` / ``src_colors`` are passed to ``forward``
            sigma_color: when not None, the structure embedding creates the
                projection it applies to the Gaussian-parameter embedding
            dropout: dropout passed to ``RPEConditionalTransformer``
            activation_fn: activation function
            reduction_a: reduction mode of angular embedding ['max', 'mean']
        """
        super(GeometricTransformer, self).__init__()

        # Relative geometric embedding, shared by the reference and source clouds.
        self.embedding = GeometricStructureEmbedding(
            hidden_dim,
            sigma_d,
            sigma_a,
            angle_k,
            sigma_color=sigma_color,
            reduction_a=reduction_a,
            sigma_gs=sigma_gs,
            sigma_hd=sigma_hd,
        )

        self.in_proj = nn.Linear(input_dim, hidden_dim)

        self.transformer = RPEConditionalTransformer(
            blocks, hidden_dim, num_heads, dropout=dropout, activation_fn=activation_fn
        )

        self.out_proj = nn.Linear(hidden_dim, output_dim)

    def forward(
        self,
        ref_points,
        src_points,
        ref_feats,
        src_feats,
        ref_masks=None,
        src_masks=None,
        ref_colors=None,
        src_colors=None,
        ref_gs=None,
        src_gs=None,
    ):
        r"""Geometric Transformer

        Args:
            ref_points (Tensor): (B, N, 3)
            src_points (Tensor): (B, M, 3)
            ref_feats (Tensor): (B, N, C)
            src_feats (Tensor): (B, M, C)
            ref_masks (Optional[BoolTensor]): (B, N)
            src_masks (Optional[BoolTensor]): (B, M)
            ref_colors (Optional[Tensor]): passed as ``hsv`` to the structure
                embedding of the reference cloud
            src_colors (Optional[Tensor]): passed as ``hsv`` to the structure
                embedding of the source cloud
            ref_gs (Optional[Tensor]): passed as ``gs_params`` for the reference cloud
            src_gs (Optional[Tensor]): passed as ``gs_params`` for the source cloud

        Returns:
            ref_feats: torch.Tensor (B, N, output_dim)
            src_feats: torch.Tensor (B, M, output_dim)
        """
        ref_embeddings = self.embedding(ref_points, hsv=ref_colors, gs_params=ref_gs)
        src_embeddings = self.embedding(src_points, hsv=src_colors, gs_params=src_gs)

        ref_feats = self.in_proj(ref_feats)
        src_feats = self.in_proj(src_feats)

        ref_feats, src_feats = self.transformer(
            ref_feats,
            src_feats,
            ref_embeddings,
            src_embeddings,
            masks0=ref_masks,
            masks1=src_masks,
        )

        ref_feats = self.out_proj(ref_feats)
        src_feats = self.out_proj(src_feats)

        return ref_feats, src_feats

